# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Authors:
#  - Burkhard Ringlein <ngl@zurich.ibm.com>
#  - Jan van Lunteren <jvl@zurich.ibm.com>
#  - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
#  - Thomas Parnell <tpa@zurich.ibm.com>

import torch
import triton
import triton.language as tl

# from vllm.logger import init_logger

# logger = init_logger(__name__)


@triton.jit
def cdiv_fn(x, y):
    # Ceiling division: the smallest q with q * y >= x.
    # Intended for x >= 0 and y > 0 (e.g. number of KV blocks needed to hold
    # a sequence of length x with block size y).
    padded = x + y - 1
    return padded // y


@triton.jit
def apply_softcap(S, x):
    # Soft-cap attention logits: returns x * tanh(S / x). For x > 0 this
    # squashes every finite S into the open interval (-x, x).
    #
    # tanh(z) is evaluated as sign(z) * (1 - e^{-2|z|}) / (1 + e^{-2|z|}).
    # The exp argument is always <= 0, so it never overflows. The naive form
    # (e^z - e^{-z}) / (e^z + e^{-z}) overflows to inf / inf = NaN once
    # |S / x| exceeds ~88 in float32, i.e. for large logits or a small cap.
    z = S / x
    e = tl.exp(-2.0 * tl.abs(z))
    t = (1.0 - e) / (1.0 + e)
    return x * tl.where(z >= 0, t, -t)


@triton.jit
def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
                 BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr):
    # Binary search over the first num_seqs entries of query_start_len_ptr.
    # Returns the largest index i in [0, num_seqs) whose key is <= target_idx,
    # or -1 if no key is. The key of sequence i is:
    #   - use_q_block_mode=True:  query_start_len[i] // BLOCK_Q + i, the
    #     index of sequence i's first q block in the global q-block numbering;
    #   - use_q_block_mode=False: query_start_len[i] itself (a token index).
    # Assumes the keys are non-decreasing in i.
    lo = 0
    hi = num_seqs
    while lo < hi:
        probe = (lo + hi) // 2
        start = tl.load(query_start_len_ptr + probe)
        # use_q_block_mode is constexpr, so only one branch is compiled.
        if use_q_block_mode:
            key = start // BLOCK_Q + probe
        else:
            key = start
        if key > target_idx:
            hi = probe
        else:
            lo = probe + 1
    return lo - 1


@triton.jit
def kernel_unified_attention_2d(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        scale,  # float32
        k_scale,  # float32
        v_scale,  # float32
        softcap,  # float32
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        USE_SOFTCAP: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.constexpr,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.constexpr,  # int
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        num_seqs: tl.int32,
        BLOCK_M: tl.constexpr,  # int
):
    # Paged attention over a 2D launch grid: program_id(0) is a global q-block
    # index (a block of BLOCK_Q query tokens) and program_id(1) a KV head.
    # Each program works on a BLOCK_M x HEAD_SIZE_PADDED tile of query rows.
    # Row m maps to query token m // num_queries_per_kv of the q block and to
    # query head kv_head_idx * num_queries_per_kv + m % num_queries_per_kv.
    #
    # NOTE(review): this kernel is currently a debug stub. seq_idx,
    # q_block_local_idx, the query start/stop indices and seq_len are all
    # hard-coded to 0. cur_batch_query_len is therefore 0 and the early return
    # below is always taken. The Q load, the KV-tile loop and the output store
    # are commented out, so nothing is ever written to output_ptr. The full
    # logic is preserved in the comments.
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    # DEBUG STUB: sequence lookup disabled; seq_idx pinned to 0.
    # seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
    #                        BLOCK_Q, True)
    seq_idx = 0
    # q_block_start_idx = tl.load(query_start_len_ptr +
    #                             seq_idx) // BLOCK_Q + seq_idx

    # q_block_local_idx = q_block_global_idx - q_block_start_idx
    q_block_local_idx = 0

    # cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_start_index = 0
    # cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
    cur_batch_in_all_stop_index = 0

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    # The host launches an upper bound of q blocks. Programs whose q block
    # starts past this sequence's query tokens have nothing to do.
    # With the stubbed values above (query length 0) this always returns.
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    # Query-token position (within this sequence) of each tile row.
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    # Token index (axis 0) and query-head index (axis 1) of each tile row.
    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv
    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    # Masks: real (unpadded) head dims, rows inside the sequence's query
    # range, and rows mapping to an existing query head.
    # dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    # query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    # query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
    dim_mask = offs_d < HEAD_SIZE
    query_mask_0 = query_pos < cur_batch_query_len
    query_mask_1 = query_offset_1 < num_query_heads

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    # Q = tl.load(
    #     query_ptr + query_offset,
    #     mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    #     other=0.0,
    # )

    block_table_offset = seq_idx * block_table_stride

    # Online-softmax state: M is the running row maximum and L the running
    # row sum of exponentials (see the commented-out loop). L starts at 1.0;
    # presumably so that the epilogue acc / L never divides by zero when no
    # KV block is visited.
    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    # seq_len = tl.load(seq_lens_ptr + seq_idx)
    seq_len = 0
    # context length for this particular sequences
    context_len = seq_len - cur_batch_query_len

    # alibi slope for each tile row's query head (0.0 for padding rows)
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
                              mask=query_mask_1,
                              other=0.0)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # DISABLED: main KV-tile loop (online softmax over all KV blocks of the
    # sequence) and the epilogue store, commented out while debugging a hang.
    # iterate through tiles
    # for j in range(0, num_blocks):

    #     physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

    #     offs_n = tl.arange(0, BLOCK_SIZE)

    #     v_offset = (physical_block_idx * stride_v_cache_0 +
    #                 kv_head_idx * stride_v_cache_2 +
    #                 offs_d[None, :] * stride_v_cache_3 +
    #                 offs_n[:, None] * stride_v_cache_1)

    #     # ====== modify =======
    #     # k_offset = (physical_block_idx * stride_k_cache_0 +
    #     #             kv_head_idx * stride_k_cache_2 +
    #     #             offs_d[:, None] * stride_k_cache_3 +
    #     #             offs_n[None, :] * stride_k_cache_1)
    #     k_offset = (physical_block_idx * stride_k_cache_0 +
    #                 kv_head_idx * stride_k_cache_2 +
    #                 offs_d[None, :] * stride_k_cache_3 +
    #                 offs_n[:, None] * stride_k_cache_1)
    #     # ====== modify =======


    #     # K : (HEAD_SIZE, BLOCK_SIZE)
    #     # ======= modify =======
    #     K_load = tl.load(key_cache_ptr + k_offset,
    #                      mask=dim_mask[None, :],
    #                      other=0.0)
    #     # ======= modify =======

    #     # if K_load.dtype.is_fp8():
    #     #     if Q.dtype.is_fp8():
    #     #         K = K_load
    #     #     else:
    #     #         K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
    #     # else:
    #     #     K = K_load
    #     K = K_load

    #     # V : (BLOCK_SIZE, HEAD_SIZE)
    #     V_load = tl.load(value_cache_ptr + v_offset,
    #                      mask=dim_mask[None, :],
    #                      other=0.0)

    #     # if V_load.dtype.is_fp8():
    #     #     if Q.dtype.is_fp8():
    #     #         V = V_load
    #     #     else:
    #     #         V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
    #     # else:
    #     #     V = V_load
    #     V = V_load

    #     seq_offset = j * BLOCK_SIZE + offs_n

    #     # ============ modify =============
    #     bounary = context_len + query_pos[:, None]
    #     seq_mask = seq_offset[None, :] <= bounary
    #     # ============ modify =============

    #     # S : (BLOCK_M, BLOCK_SIZE)
    #     S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)

    #     # ======== modify =======
    #     K_trans = tl.trans(K)
    #     S += scale * tl.dot(Q, K_trans)
    #     # ======== modify =======

    #     if USE_SOFTCAP:
    #         S = apply_softcap(S, softcap)

    #     S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
    #                  S, float("-inf"))

    #     if SLIDING_WINDOW > 0:
    #         S = tl.where((context_len + query_pos[:, None] - seq_offset)
    #                      < SLIDING_WINDOW, S, float("-inf"))

    #     if USE_ALIBI_SLOPES:
    #         S += alibi_slope[:, None] * (seq_offset - context_len)

    #     # compute running maximum
    #     # m_j : (BLOCK_M,)
    #     m_j = tl.maximum(M, tl.max(S, 1))
    #     # m_j = tl.zeros([BLOCK_M, ], dtype=tl.float32)
    #     # For sliding window there's a chance the max is -inf due to masking of
    #     # the entire row. In this case we need to set m_j 0 to avoid NaN
    #     m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

    #     # P : (BLOCK_M, BLOCK_SIZE)
    #     # This line hangs.
    #     # It used to run fine, but stopped working after the version update.
    #     # P = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32)
    #     P = tl.exp(S - m_j[:, None])
    #     # P = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32)

    #     # l_j : (BLOCK_M,)
    #     l_j = tl.sum(P, 1)
    #     # l_j = tl.zeros([BLOCK_M, ], dtype=tl.float32)

    #     # alpha : (BLOCK_M, )
    #     alpha = tl.exp(M - m_j)

    #     # acc : (BLOCK_M, HEAD_SIZE_PADDED)
    #     acc = acc * alpha[:, None]

    #     # update constants
    #     L = L * alpha + l_j
    #     M = m_j

    #     # acc : (BLOCK_M, HEAD_SIZE_PADDED)
    #     # P_cast = P.to(V.dtype)
    #     P = P.to(query_ptr.dtype.element_ty)
    #     acc += tl.dot(P, V)

    # # epilogue
    # acc = acc / L[:, None]

    # output_offset = (query_offset_0[:, None] * output_stride_0 +
    #                  query_offset_1[:, None] * output_stride_1 +
    #                  offs_d[None, :])

    # tl.store(
    #     output_ptr + output_offset,
    #     acc,
    #     mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    # )


@triton.jit
def kernel_unified_attention_3d(
        segm_output_ptr,
        # [num_tokens, num_query_heads, num_segments, head_size_padded]
        segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
        segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        scale,  # float32
        k_scale,  # float32
        v_scale,  # float32
        softcap,  # float32
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        USE_SOFTCAP: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.constexpr,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.constexpr,  # int
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        num_seqs: tl.int32,
        BLOCK_M: tl.constexpr,  # int
        NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
):
    # Split-KV ("segmented") variant of the 2D kernel: program_id(0) is a
    # global q-block index, program_id(1) a KV head and program_id(2) a KV
    # segment. Each program processes only its segment's KV blocks. It writes
    # an unnormalized partial result plus its row max (M) and exp-sum (L) to
    # the segm_* buffers; reduce_segments then merges the segments.
    #
    # NOTE(review): this kernel is currently a debug stub. seq_idx,
    # q_block_local_idx, the query start/stop indices and seq_len are all
    # hard-coded to 0. cur_batch_query_len is therefore 0 and the first early
    # return is always taken. The Q load, ALiBi load, KV-tile loop and all
    # stores are commented out, so nothing is written to the segm_* buffers.
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
    segm_idx = tl.program_id(2)

    # DEBUG STUB: sequence lookup disabled; seq_idx pinned to 0.
    # seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
    #                        BLOCK_Q, True)
    seq_idx = 0

    # q_block_start_idx = tl.load(query_start_len_ptr +
    #                             seq_idx) // BLOCK_Q + seq_idx

    # q_block_local_idx = q_block_global_idx - q_block_start_idx
    q_block_local_idx = 0

    # cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    # cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
    cur_batch_in_all_start_index = 0
    cur_batch_in_all_stop_index = 0

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    # Programs whose q block starts past this sequence's query tokens have
    # nothing to do. With the stubbed values above this always returns.
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    # sequence len for this particular sequence
    # seq_len = tl.load(seq_lens_ptr + seq_idx)
    seq_len = 0

    # Split the sequence's KV blocks evenly across NUM_SEGMENTS_PER_SEQ
    # segments. Segments starting past seq_len are empty and exit early.
    num_segments = NUM_SEGMENTS_PER_SEQ
    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)

    if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)

    # Query-token position (within this sequence) of each tile row.
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    # Token index (axis 0) and query-head index (axis 1) of each tile row.
    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv

    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    # Masks: real (unpadded) head dims, rows inside the sequence's query
    # range, and rows mapping to an existing query head.
    # ========== modify ===========
    # dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    # query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    # query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
    dim_mask = offs_d < HEAD_SIZE
    query_mask_0 = query_pos < cur_batch_query_len
    query_mask_1 = query_offset_1 < num_query_heads
    # ========== modify ===========

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    # Q = tl.load(
    #     query_ptr + query_offset,
    #     mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    #     other=0.0,
    # )

    block_table_offset = seq_idx * block_table_stride

    # Online-softmax state for this segment: running row max (M), running row
    # exp-sum (L) and unnormalized output accumulator (acc).
    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # context length for this particular sequences
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    # if USE_ALIBI_SLOPES:
    #     alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
    #                           mask=query_mask_1,
    #                           other=0.0)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # DISABLED: KV-tile loop over this segment's blocks and the stores of the
    # partial results, commented out while debugging a hang.
    # iterate through tiles within current segment
    # for j in range(
    #         segm_idx * blocks_per_segment,
    #         min((segm_idx + 1) * blocks_per_segment, num_blocks),
    # ):
    #     physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

    #     offs_n = tl.arange(0, BLOCK_SIZE)

    #     v_offset = (physical_block_idx * stride_v_cache_0 +
    #                 kv_head_idx * stride_v_cache_2 +
    #                 offs_d[None, :] * stride_v_cache_3 +
    #                 offs_n[:, None] * stride_v_cache_1)

    #     # ======== modify ==========
    #     # k_offset = (physical_block_idx * stride_k_cache_0 +
    #     #             kv_head_idx * stride_k_cache_2 +
    #     #             offs_d[:, None] * stride_k_cache_3 +
    #     #             offs_n[None, :] * stride_k_cache_1)
    #     k_offset = (physical_block_idx * stride_k_cache_0 +
    #                 kv_head_idx * stride_k_cache_2 +
    #                 offs_d[None, :] * stride_k_cache_3 +
    #                 offs_n[:, None] * stride_k_cache_1)
    #     # ======== modify ==========

    #     # K : (HEAD_SIZE, BLOCK_SIZE)
    #     # ======== modify =========
    #     K_load = tl.load(key_cache_ptr + k_offset,
    #                      mask=dim_mask[None, :],
    #                      other=0.0)
    #     # ======== modify =========

    #     # if K_load.dtype.is_fp8():
    #     #     if Q.dtype.is_fp8():
    #     #         K = K_load
    #     #     else:
    #     #         K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
    #     # else:
    #     #     K = K_load
    #     K = K_load

    #     # V : (BLOCK_SIZE, HEAD_SIZE)
    #     V_load = tl.load(value_cache_ptr + v_offset,
    #                      mask=dim_mask[None, :],
    #                      other=0.0)

    #     # if V_load.dtype.is_fp8():
    #     #     if Q.dtype.is_fp8():
    #     #         V = V_load
    #     #     else:
    #     #         V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
    #     # else:
    #     #     V = V_load
    #     V = V_load

    #     seq_offset = j * BLOCK_SIZE + offs_n

    #     # ====== modify ===========
    #     bounary = context_len + query_pos[:, None]
    #     seq_mask = seq_offset[None, :] <= bounary
    #     # ====== modify ===========

    #     # S : (BLOCK_M, BLOCK_SIZE)
    #     S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)

    #     # S += scale * tl.dot(Q, K)
    #     # ======== modify =======
    #     K_trans = tl.trans(K)
    #     S += scale * tl.dot(Q, K_trans)
    #     # ======== modify =======

    #     if USE_SOFTCAP:
    #         S = apply_softcap(S, softcap)

    #     S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
    #                  S, float("-inf"))

    #     if SLIDING_WINDOW > 0:
    #         S = tl.where((context_len + query_pos[:, None] - seq_offset)
    #                      < SLIDING_WINDOW, S, float("-inf"))

    #     if USE_ALIBI_SLOPES:
    #         S += alibi_slope[:, None] * (seq_offset - context_len)

    #     # compute running maximum
    #     # m_j : (BLOCK_M,)
    #     m_j = tl.maximum(M, tl.max(S, 1))
    #     # For sliding window there's a chance the max is -inf due to masking of
    #     # the entire row. In this case we need to set m_j 0 to avoid NaN
    #     m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

    #     # P : (BLOCK_M, BLOCK_SIZE,)
    #     # This line hangs.
    #     P = tl.exp(S - m_j[:, None])

    #     # l_j : (BLOCK_M,)
    #     l_j = tl.sum(P, 1)

    #     # alpha : (BLOCK_M, )
    #     alpha = tl.exp(M - m_j)

    #     # acc : (BLOCK_M, HEAD_SIZE_PADDED)
    #     acc = acc * alpha[:, None]

    #     # update constants
    #     L = L * alpha + l_j
    #     M = m_j

    #     # acc : (BLOCK_M, HEAD_SIZE_PADDED)
    #     # acc += tl.dot(P.to(V.dtype), V)
    #     # P_cast = P.to(V.dtype)
    #     P = P.to(query_ptr.dtype.element_ty)
    #     acc += tl.dot(P, V)

    # segm_output_offset = (
    #     query_offset_0[:, None] *
    #     (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
    #     query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
    #     segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :])
    # tl.store(
    #     segm_output_ptr + segm_output_offset,
    #     acc,
    #     mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    # )
    # segm_offset = (query_offset_0 *
    #                (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
    #                query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx)
    # tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
    # tl.store(segm_expsum_ptr + segm_offset,
    #          L,
    #          mask=query_mask_0 & query_mask_1)


@triton.jit
def reduce_segments(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        segm_output_ptr,
        # [num_tokens, num_query_heads, max_num_segments, head_size_padded]
        segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
        segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
        seq_lens_ptr,  # [num_seqs]
        num_seqs,  # int
        num_query_heads: tl.constexpr,  # int
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        block_table_stride: tl.int64,  # int, unused by this kernel
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int, real head size (may be unpadded)
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
):
    # Merge per-segment partial attention results into the final output.
    # One program per (query token, query head). Every segment's partial
    # output and exp-sum is rescaled from its own row max to the global max
    # over segments. The outputs are then summed and divided by the total
    # exp-sum (the result is 0.0 when that sum is 0).
    token_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    # Locate the sequence owning this query token (token-index search).
    seq_idx = find_seq_idx(query_start_len_ptr, token_idx, num_seqs,
                           BLOCK_Q, False)
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # Segmentation rule: each segment spans blocks_per_segment KV blocks.
    # Only the first used_segments segments hold data.
    blocks_per_segment = cdiv_fn(seq_len, NUM_SEGMENTS_PER_SEQ * BLOCK_SIZE)
    used_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE)

    segm_ids = tl.arange(0, NUM_SEGMENTS_PER_SEQ)
    dim_ids = tl.arange(0, HEAD_SIZE_PADDED)
    segm_mask = segm_ids < tl.full(
        [NUM_SEGMENTS_PER_SEQ], used_segments, dtype=tl.int32)
    dim_mask = dim_ids < HEAD_SIZE

    # Per-segment row maxima; unused segments read -inf so they get weight 0.
    stat_offs = (token_idx * (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
                 head_idx * NUM_SEGMENTS_PER_SEQ + segm_ids)
    segm_max = tl.load(segm_max_ptr + stat_offs,
                       mask=segm_mask,
                       other=float("-inf"))
    global_max = tl.max(segm_max)
    rescale = tl.exp(segm_max - global_max)

    segm_expsum = tl.load(segm_expsum_ptr + stat_offs,
                          mask=segm_mask,
                          other=0.0)
    total_expsum = tl.sum(segm_expsum * rescale)

    # Per-segment partial outputs, laid out with a padded head dimension.
    out_part_offs = (
        token_idx *
        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        segm_ids[:, None] * HEAD_SIZE_PADDED + dim_ids[None, :])
    out_parts = tl.load(
        segm_output_ptr + out_part_offs,
        mask=segm_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )
    weighted_sum = tl.sum(out_parts * rescale[:, None], 0)
    result = tl.where(total_expsum == 0.0, 0.0, weighted_sum / total_expsum)

    dst_offs = (token_idx * output_stride_0 + head_idx * output_stride_1 +
                dim_ids)
    tl.store(output_ptr + dst_offs, result, mask=dim_mask)


def unified_attention(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    max_seqlen_q,
    seqused_k,
    max_seqlen_k,
    softmax_scale,
    causal,
    window_size,
    block_table,
    softcap,
    q_descale,
    k_descale,
    v_descale,
    alibi_slopes=None,
):
    """Host entry point for the unified paged-attention Triton kernels.

    Validates the arguments and computes the launch geometry. A batch with
    any query longer than one token, or more than 128 (q block, KV head)
    pairs, selects the 2D kernel. Otherwise the split-KV 3D kernel followed
    by ``reduce_segments`` is used.

    NOTE(review): both kernel launches are currently commented out (each
    branch is ``pass``). This function only validates and computes launch
    parameters; ``out`` is left untouched and ``None`` is returned.

    Args:
        q: query tensor, indexed as [num_tokens, num_query_heads, head_size].
        k: paged key cache. ``k.shape[2]`` is read as the number of KV heads.
        v: paged value cache. ``v.shape[1]`` is read as the cache block size.
        out: output tensor the kernels write to (same layout as ``q``).
        cu_seqlens_q: per-sequence query start offsets, [num_seqs + 1].
        max_seqlen_q: longest query length in the batch.
        seqused_k: per-sequence KV lengths; ``len(seqused_k)`` is num_seqs.
        max_seqlen_k: unused.
        softmax_scale, window_size, block_table, softcap, k_descale,
        v_descale, alibi_slopes: forwarded to the kernels (in the
            commented-out launches).
        causal: must be True.
        q_descale: must be None.

    Raises:
        AssertionError: if ``causal`` is False, ``q_descale`` is given, or a
            1-byte (e.g. fp8) query is used with a cache block size < 32.
    """
    assert causal, "Only causal attention is supported"
    assert q_descale is None, "Q scales not supported"

    block_size = v.shape[1]
    assert q.element_size() >= 2 or block_size >= 32, \
        "Block size must be at least 32 for fp8"

    use_alibi_slopes = alibi_slopes is not None

    num_seqs = len(seqused_k)
    num_query_heads = q.shape[1]
    num_kv_heads = k.shape[2]
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_size = q.shape[2]

    # A kernel tile has BLOCK_M rows = BLOCK_Q query tokens x the
    # num_queries_per_kv query heads sharing one KV head. With BLOCK_M fixed
    # at 16, groups larger than 16 gave BLOCK_Q == 0, and the division below
    # raised ZeroDivisionError. Grow BLOCK_M to the next power of two instead
    # (tl.arange needs a power-of-two extent), so each tile holds at least
    # one query token.
    if num_queries_per_kv <= 16:
        BLOCK_M = 16
    else:
        BLOCK_M = 1 << (num_queries_per_kv - 1).bit_length()
    BLOCK_Q = BLOCK_M // num_queries_per_kv

    # Ideally we would launch with kernel with:
    # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
    # However, it is slow to realize the query_lens on cpu.
    # Instead we use upper-bound:
    # \sum_i[ceil(query_len[i] / BLOCK_Q)]
    #   <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
    #    = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
    #   <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
    #    = floor(q.shape[0] / BLOCK_Q) + num_seqs
    total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs

    # if batch contains a prefill
    if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
        # DISABLED while debugging the kernel hang.
        # kernel_unified_attention_2d[(
        #     total_num_q_blocks,
        #     num_kv_heads,
        # )](
        #     output_ptr=out,
        #     query_ptr=q,
        #     key_cache_ptr=k,
        #     value_cache_ptr=v,
        #     block_tables_ptr=block_table,
        #     seq_lens_ptr=seqused_k,
        #     alibi_slopes_ptr=alibi_slopes,
        #     scale=softmax_scale,
        #     k_scale=k_descale,
        #     v_scale=v_descale,
        #     softcap=softcap,
        #     num_query_heads=num_query_heads,
        #     num_queries_per_kv=num_queries_per_kv,
        #     block_table_stride=block_table.stride(0),
        #     query_stride_0=q.stride(0),
        #     query_stride_1=q.stride(1),
        #     output_stride_0=out.stride(0),
        #     output_stride_1=out.stride(1),
        #     BLOCK_SIZE=block_size,
        #     HEAD_SIZE=head_size,
        #     HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
        #     USE_ALIBI_SLOPES=use_alibi_slopes,
        #     USE_SOFTCAP=(softcap > 0),
        #     SLIDING_WINDOW=(1 + window_size[0]),
        #     stride_k_cache_0=k.stride(0),
        #     stride_k_cache_1=k.stride(1),
        #     stride_k_cache_2=k.stride(2),
        #     stride_k_cache_3=k.stride(3),
        #     stride_v_cache_0=v.stride(0),
        #     stride_v_cache_1=v.stride(1),
        #     stride_v_cache_2=v.stride(2),
        #     stride_v_cache_3=v.stride(3),
        #     query_start_len_ptr=cu_seqlens_q,
        #     BLOCK_Q=BLOCK_Q,
        #     num_seqs=num_seqs,
        #     BLOCK_M=BLOCK_M,
        # )
        pass
    else:
        # DISABLED while debugging the kernel hang.
        # for initial version, NUM_SEGMENTS = 16 is chosen as a default
        # value that showed good performance in tests
        # NUM_SEGMENTS = 16

        # segm_output = torch.empty(
        #     q.shape[0],
        #     num_query_heads,
        #     NUM_SEGMENTS,
        #     triton.next_power_of_2(head_size),
        #     dtype=torch.float32,
        #     device=q.device,
        # )
        # segm_max = torch.empty(
        #     q.shape[0],
        #     num_query_heads,
        #     NUM_SEGMENTS,
        #     dtype=torch.float32,
        #     device=q.device,
        # )
        # segm_expsum = torch.empty(
        #     q.shape[0],
        #     num_query_heads,
        #     NUM_SEGMENTS,
        #     dtype=torch.float32,
        #     device=q.device,
        # )

        # kernel_unified_attention_3d[(
        #     total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
        #         segm_output_ptr=segm_output,
        #         segm_max_ptr=segm_max,
        #         segm_expsum_ptr=segm_expsum,
        #         query_ptr=q,
        #         key_cache_ptr=k,
        #         value_cache_ptr=v,
        #         block_tables_ptr=block_table,
        #         seq_lens_ptr=seqused_k,
        #         alibi_slopes_ptr=alibi_slopes,
        #         scale=softmax_scale,
        #         k_scale=k_descale,
        #         v_scale=v_descale,
        #         softcap=softcap,
        #         num_query_heads=num_query_heads,
        #         num_queries_per_kv=num_queries_per_kv,
        #         block_table_stride=block_table.stride(0),
        #         query_stride_0=q.stride(0),
        #         query_stride_1=q.stride(1),
        #         BLOCK_SIZE=block_size,
        #         HEAD_SIZE=head_size,
        #         HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
        #         USE_ALIBI_SLOPES=use_alibi_slopes,
        #         USE_SOFTCAP=(softcap > 0),
        #         SLIDING_WINDOW=(1 + window_size[0]),
        #         stride_k_cache_0=k.stride(0),
        #         stride_k_cache_1=k.stride(1),
        #         stride_k_cache_2=k.stride(2),
        #         stride_k_cache_3=k.stride(3),
        #         stride_v_cache_0=v.stride(0),
        #         stride_v_cache_1=v.stride(1),
        #         stride_v_cache_2=v.stride(2),
        #         stride_v_cache_3=v.stride(3),
        #         query_start_len_ptr=cu_seqlens_q,
        #         BLOCK_Q=BLOCK_Q,
        #         num_seqs=num_seqs,
        #         BLOCK_M=BLOCK_M,
        #         NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
        #     )

        # reduce_segments[(q.shape[0], num_query_heads)](
        #     output_ptr=out,
        #     segm_output_ptr=segm_output,
        #     segm_max_ptr=segm_max,
        #     segm_expsum_ptr=segm_expsum,
        #     seq_lens_ptr=seqused_k,
        #     num_seqs=num_seqs,
        #     num_query_heads=num_query_heads,
        #     output_stride_0=out.stride(0),
        #     output_stride_1=out.stride(1),
        #     block_table_stride=block_table.stride(0),
        #     BLOCK_SIZE=block_size,
        #     HEAD_SIZE=head_size,
        #     HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
        #     query_start_len_ptr=cu_seqlens_q,
        #     BLOCK_Q=BLOCK_Q,
        #     NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
        # )
        pass